
import os
from datasets import load_dataset

def get_dataset_by_file(data_file="data/ChnSentiCorp_htl_all.csv"):
    """Load a CSV file as a dataset and drop rows that have no review text.

    Args:
        data_file: Path of the CSV file to read. Defaults to the
            ``data/ChnSentiCorp_htl_all.csv`` file this script was written
            for, so existing zero-argument calls behave exactly as before.

    Returns:
        The ``train`` split returned by ``load_dataset`` for the CSV file,
        filtered so that only rows whose ``review`` value is not ``None``
        remain.
    """
    dataset = load_dataset("csv", data_files=data_file, split="train")
    # Rows with a missing review come through as None; remove them here so
    # later steps do not have to handle None values.
    dataset = dataset.filter(lambda row: row["review"] is not None)
    return dataset


# Approach 1: persist the dataset through its own save_to_disk() method
def save_tensor_dataset(dataset, save_dir="./data"):
    """Save ``dataset`` beneath ``save_dir`` and report the directory used.

    Args:
        dataset: Any object exposing ``save_to_disk(path)``; this script
            passes the result of ``get_dataset_by_file()``.
        save_dir: Directory that receives the saved data. It is created
            (including missing parents) when it does not exist yet.

    Returns:
        None. The data is written to ``<save_dir>/dataset.pt`` and a
        confirmation line is printed.
    """
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): the ".pt" suffix is only a name handed to save_to_disk();
    # presumably the result is a dataset directory, not a torch file -- confirm.
    target_path = os.path.join(save_dir, "dataset.pt")
    dataset.save_to_disk(target_path)

    print(f"数据已保存到 {save_dir}")

if __name__ == "__main__":
    # Script entry point: load the filtered CSV dataset and save it using
    # the default output directory of save_tensor_dataset().
    save_tensor_dataset(get_dataset_by_file())
